import numpy as np
import scipy
import sklearn
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_X_y
from sklearn.model_selection import KFold
from model.rkhs.hyper_parameter import _BaseRKHSIV, _check_auto
from sklearn.preprocessing import RobustScaler as Scaler




class RKHSIV(BaseEstimator, _BaseRKHSIV):
    """Kernel (RKHS) conditional-moment estimator.

    ``fit`` solves ``(Kh M Kh + alpha Kh) a = Kh M y`` for the coefficient
    vector ``a``, where ``Kh`` is the kernel matrix on the standardized
    regressors and ``M = RootKf (Kf / (2 n delta^2) + I / 2)^{-1} RootKf`` is
    built from the kernel matrix ``Kf`` on the standardized conditioning
    variables. ``predict`` evaluates ``Kh(X_new, X_train) @ a``.
    """

    def __init__(self, gamma_hq=0.1, gamma_gm='auto',
                 delta_scale='auto', delta_exp='auto', alpha_scale='auto'):
        """
        Parameters:
            gamma_hq : the gamma parameter for the rbf kernel of h
            gamma_gm : the gamma parameter for the rbf kernel of f (on the conditioning variables)
            delta_scale : the scale of the critical radius; delta_n = delta_scale / n**(delta_exp)
            delta_exp : the exponent of the critical radius; delta_n = delta_scale / n**(delta_exp)
            alpha_scale : the scale of the regularization; alpha = alpha_scale * (delta**4)
        """
        self.gamma_gm = gamma_gm
        self.gamma_hq = gamma_hq
        self.delta_scale = delta_scale  # worst-case critical value of RKHS spaces
        self.delta_exp = delta_exp
        self.alpha_scale = alpha_scale  # regularization strength from Theorem 5

    def fit(self, AWX, model_target, AZX, type):
        """Fit the coefficient vector ``self.a``.

        Parameters:
            AWX : array-like with n rows
            model_target : array-like of length n; the target y
            AZX : array-like with n rows
            type : 'estimate_h' uses AWX as the regressors X and AZX as the
                conditioning variables; any other value swaps the two roles.
                (The name shadows the builtin but is kept for callers.)

        Returns:
            self
        """
        if type == 'estimate_h':
            X, condition = AWX, AZX
        else:
            X, condition = AZX, AWX
        X, y = check_X_y(X, model_target, accept_sparse=True)
        condition, y = check_X_y(condition, y, accept_sparse=True)

        # Standardize condition and get gamma_gm -> Kf -> RootKf
        condition = Scaler().fit_transform(condition)
        # Keep the resolved gamma out of the constructor parameter. Writing
        # it back to self.gamma_gm would freeze an 'auto' setting after the
        # first fit (presumably _get_gamma_gm resolves 'auto' from
        # self.gamma_gm; verify). It would also leak into get_params()/clone().
        gamma_gm = self._get_gamma_gm(condition=condition)
        self.gamma_gm_ = gamma_gm
        Kf = self._get_kernel_gm(condition, gamma_gm=gamma_gm)
        RootKf = scipy.linalg.sqrtm(Kf).astype(float)

        # Standardize X; the fitted scaler and training rows are reused in predict
        self.transX = Scaler()
        self.transX.fit(X)
        X = self.transX.transform(X)
        self.X = X.copy()
        Kh = self._get_kernel_hq(X, gamma_hq=self.gamma_hq)

        # delta & alpha
        n = X.shape[0]  # number of samples
        delta = self._get_delta(n)
        alpha = self._get_alpha(delta, self._get_alpha_scale())

        # M = RootKf (Kf / (2 n delta^2) + I / 2)^{-1} RootKf
        M = RootKf @ np.linalg.inv(
            Kf / (2 * n * delta**2) + np.eye(n) / 2) @ RootKf

        # Solve (Kh M Kh + alpha Kh) a = Kh M y; Kh @ M is shared by both sides.
        # lstsq tolerates a singular (e.g. rank-deficient kernel) system.
        KhM = Kh @ M
        self.a = np.linalg.lstsq(KhM @ Kh + alpha * Kh, KhM @ y, rcond=None)[0]
        return self

    def predict(self, X):
        """Return ``Kh(X, X_train) @ a`` for new regressors X.

        X is standardized with the scaler fitted in ``fit`` before the kernel
        against the stored training rows is evaluated.
        """
        X = self.transX.transform(X)
        return self._get_kernel_hq(X, Y=self.X, gamma_hq=self.gamma_hq) @ self.a

class RKHSIVCV(RKHSIV):
    """RKHSIV whose gamma_hq and alpha_scale are chosen by K-fold cross-validation.

    Each (gamma_hq, alpha_scale) pair is scored on every fold by the quadratic
    form ``res^T M_test res / n_test^2`` of the held-out residuals. The pair
    with the lowest average score is then refit on the full sample.
    """

    def __init__(self, gamma_gm='auto', gamma_hqs='auto', n_gamma_hqs=20,
                 delta_scale='auto', delta_exp='auto', alpha_scales='auto', n_alphas=30, cv=6):
        """
        Parameters:
            gamma_gm : the gamma parameter for the kernel of f
            gamma_hqs : the list of gamma parameters for kernel of h
            n_gamma_hqs : size of the 'auto' gamma_hq grid; the grid has
                n_gamma_hqs - 1 values (the 0 and 1 quantiles are excluded),
                so it must be at least 2
            delta_scale : the scale of the critical radius; delta_n = delta_scale / n**(delta_exp)
            delta_exp : the exponent of the critical radius; delta_n = delta_scale / n**(delta_exp)
            alpha_scales : a list of scale of the regularization to choose from; alpha = alpha_scale * (delta**4)
            n_alphas : how many alpha_scales to try
            cv : how many folds to use in cross-validation for alpha_scale, gamma_hq
        """
        self.gamma_gm = gamma_gm
        self.gamma_hqs = gamma_hqs
        self.n_gamma_hqs = n_gamma_hqs
        self.delta_scale = delta_scale  # worst-case critical value of RKHS spaces
        self.delta_exp = delta_exp  # exponent of n in the critical radius
        self.alpha_scales = alpha_scales  # regularization strength from Theorem 5
        self.n_alphas = n_alphas
        self.cv = cv

    @staticmethod
    def _compute_M(Kf, RootKf, n_eff, delta):
        """Return ``RootKf @ inv(Kf / (2 n_eff delta^2) + I / 2) @ RootKf``."""
        return RootKf @ np.linalg.inv(
            Kf / (2 * n_eff * delta**2) + np.eye(Kf.shape[0]) / 2) @ RootKf

    def _get_gamma_hqs(self, X):
        """Return the candidate gamma_hq values.

        With gamma_hqs='auto', the candidates are the reciprocals of the
        k / n_gamma_hqs quantiles (k = 1, ..., n_gamma_hqs - 1) of the
        positive pairwise squared Euclidean distances between rows of X.
        Otherwise gamma_hqs is returned unchanged.
        """
        if not _check_auto(self.gamma_hqs):
            return self.gamma_hqs
        from sklearn.metrics import pairwise_distances
        sq_dists = pairwise_distances(X=X, metric='euclidean', squared=True)
        sq_dists = sq_dists[np.tril_indices(X.shape[0], -1)]
        # Duplicate rows give zero distances. A zero quantile would give an
        # infinite gamma, and for an exp(-gamma * d^2) kernel inf * 0 = NaN
        # entries. So only strictly positive distances are used.
        sq_dists = sq_dists[sq_dists > 0]
        if sq_dists.size == 0:
            # Every row identical: the kernel is all ones for any gamma.
            sq_dists = np.ones(1)
        levels = np.arange(1, self.n_gamma_hqs) / self.n_gamma_hqs
        return 1. / np.quantile(sq_dists, levels)

    def fit(self, AWX, model_target, AZX, type):
        """Choose (gamma_hq, alpha_scale) by K-fold CV, then refit on all rows.

        Parameters:
            AWX : array-like with n rows
            model_target : array-like of length n; the target y
            AZX : array-like with n rows
            type : 'estimate_h' uses AWX as the regressors X and AZX as the
                conditioning variables; any other value swaps the two roles.

        Sets gamma_hq, best_alpha_scale, best_alpha, transX, X and a.

        Returns:
            self
        """
        if type == 'estimate_h':
            X, condition = AWX, AZX
        else:
            X, condition = AZX, AWX
        X, y = check_X_y(X, model_target, accept_sparse=True)
        condition, y = check_X_y(condition, y, accept_sparse=True)

        # Standardize condition and get gamma_gm -> Kf -> RootKf (full sample)
        condition = Scaler().fit_transform(condition)
        gamma_gm = self._get_gamma_gm(condition=condition)
        Kf = self._get_kernel_gm(condition, gamma_gm=gamma_gm)
        RootKf = scipy.linalg.sqrtm(Kf).astype(float)

        # Standardize X and get gamma_hqs
        self.transX = Scaler()
        self.transX.fit(X)
        X = self.transX.transform(X)
        self.X = X.copy()
        gamma_hqs = self._get_gamma_hqs(X)

        # delta & alpha; nominal fold sizes (KFold sizes may differ by one)
        n = X.shape[0]
        n_train = n * (self.cv - 1) / self.cv
        delta_train = self._get_delta(n_train)
        n_test = n / self.cv
        delta_test = self._get_delta(n_test)
        alpha_scales = self._get_alpha_scales()
        alphas = [self._get_alpha(delta_train, scale) for scale in alpha_scales]

        # scores[fold][gamma index][alpha index]
        scores = []
        for train, test in KFold(n_splits=self.cv).split(X):
            # Standardize with statistics of the training rows only
            transX = Scaler()
            X_train = transX.fit_transform(X[train])
            X_test = transX.transform(X[test])
            condition_train = Scaler().fit_transform(condition[train])
            Kf_train = self._get_kernel_gm(
                X=condition_train, gamma_gm=self._get_gamma_gm(condition=condition_train))
            RootKf_train = scipy.linalg.sqrtm(Kf_train).astype(float)
            M_train = self._compute_M(Kf_train, RootKf_train, n_train, delta_train)
            # M_test uses sub-blocks of the full-sample Kf / RootKf so every
            # candidate is evaluated with the same metric
            test_block = np.ix_(test, test)
            M_test = self._compute_M(Kf[test_block], RootKf[test_block], n_test, delta_test)
            y_train, y_test = y[train], y[test]

            fold_scores = []
            for gamma_hq in gamma_hqs:
                Kh_train = self._get_kernel_hq(X=X_train, gamma_hq=gamma_hq)
                KhM_train = Kh_train @ M_train
                KMK_train = KhM_train @ Kh_train
                B_train = KhM_train @ y_train
                # Cross kernel does not depend on alpha: compute once per gamma
                Kh_test_train = self._get_kernel_hq(X=X_test, Y=X_train, gamma_hq=gamma_hq)
                gamma_scores = []
                for alpha in alphas:
                    a = np.linalg.lstsq(KMK_train + alpha * Kh_train, B_train, rcond=None)[0]
                    res = y_test - Kh_test_train @ a
                    gamma_scores.append((res.T @ M_test @ res).reshape(-1)[0] / (res.shape[0]**2))
                fold_scores.append(gamma_scores)
            scores.append(fold_scores)

        avg_scores = np.mean(np.array(scores), axis=0)
        # nanargmin skips NaN scores (e.g. from a degenerate kernel); plain
        # argmin would return the first NaN and select the broken candidate
        best_ind = np.unravel_index(np.nanargmin(avg_scores), avg_scores.shape)
        self.gamma_hq = gamma_hqs[best_ind[0]]
        self.best_alpha_scale = alpha_scales[best_ind[1]]
        delta = self._get_delta(n)
        self.best_alpha = self._get_alpha(delta, self.best_alpha_scale)

        # Refit on the full sample with the selected hyper-parameters
        M = self._compute_M(Kf, RootKf, n, delta)
        Kh = self._get_kernel_hq(X, gamma_hq=self.gamma_hq)
        KhM = Kh @ M
        self.a = np.linalg.lstsq(
            KhM @ Kh + self.best_alpha * Kh, KhM @ y, rcond=None)[0]

        return self
    

